# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from math import pi

import pytest
import torch

from kornia.feature.mkd import (
    COEFFS,
    EmbedGradients,
    ExplicitSpacialEncoding,
    MKDDescriptor,
    MKDGradients,
    SimpleKD,
    VonMisesKernel,
    Whitening,
    get_grid_dict,
    get_kron_order,
    spatial_kernel_embedding,
)

from testing.base import BaseTester
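
# Tests for kornia.feature.mkd: the building blocks of the MKD local descriptor
# (patch gradients, Von Mises kernel embeddings, explicit spatial encodings,
# whitening) and the full MKDDescriptor / SimpleKD pipelines.
#
# Typical usage, as a rough sketch (shapes follow the assertions below):
#   mkd = MKDDescriptor(patch_size=32, kernel_type="concat", whitening="lw")
#   descs = mkd(patches)  # patches: (B, 1, 32, 32) -> descs: (B, 128)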


@pytest.mark.parametrize("ps", [5, 13, 25])
def test_get_grid_dict(ps):
    grid_dict = get_grid_dict(ps)
    param_keys = ["x", "y", "phi", "rho"]
    assert set(grid_dict.keys()) == set(param_keys)
    for k in param_keys:
        assert grid_dict[k].shape == (ps, ps)


@pytest.mark.parametrize("d1,d2", [(1, 1), (1, 2), (2, 1), (5, 6)])
def test_get_kron_order(d1, d2):
    out = get_kron_order(d1, d2)
    assert out.shape == (d1 * d2, 2)


def test_get_kron_order_values_and_dtype():
    d1, d2 = 2, 3
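    # Index pairs are enumerated in row-major order: (0, 0), (0, 1), ..., (1, 2).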
    expected_output = torch.tensor(
        [
            [0, 0],
            [0, 1],
            [0, 2],
            [1, 0],
            [1, 1],
            [1, 2],
        ],
        dtype=torch.int64,
    )

    actual_output = get_kron_order(d1, d2)
    assert actual_output.dtype == torch.int64
    assert torch.equal(actual_output, expected_output)


class TestMKDGradients(BaseTester):
    @pytest.mark.parametrize("ps", [5, 13, 25])
    def test_shape(self, ps, device):
        inp = torch.ones(1, 1, ps, ps).to(device)
        gradients = MKDGradients().to(device)
        out = gradients(inp)
        assert out.shape == (1, 2, ps, ps)

    @pytest.mark.parametrize("bs", [1, 5, 13])
    def test_batch_shape(self, bs, device):
        inp = torch.ones(bs, 1, 15, 15).to(device)
        gradients = MKDGradients().to(device)
        out = gradients(inp)
        assert out.shape == (bs, 2, 15, 15)

    def test_print(self, device):
        gradients = MKDGradients().to(device)
        repr(gradients)

    def test_toy(self, device):
        patch = torch.ones(1, 1, 6, 6).to(device).float()
        patch[0, 0, :, 3:] = 0
        gradients = MKDGradients().to(device)
        out = gradients(patch)
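        # The patch is a vertical step edge (ones left of column 3, zeros after);
        # output channel 0 carries gradient magnitudes, channel 1 orientations.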
        expected_mags_1 = torch.tensor([0.0, 0.0, 1.0, 1.0, 0.0, 0.0], device=device)
        expected_mags = expected_mags_1.unsqueeze(0).repeat(6, 1)
        expected_oris_1 = torch.tensor([-pi, -pi, 0.0, 0.0, -pi, -pi], device=device)
        expected_oris = expected_oris_1.unsqueeze(0).repeat(6, 1)
        self.assert_close(out[0, 0, :, :], expected_mags, atol=1e-3, rtol=1e-3)
        self.assert_close(out[0, 1, :, :], expected_oris, atol=1e-3, rtol=1e-3)

    def test_gradcheck(self, device):
        batch_size, channels, height, width = 1, 1, 13, 13
        patches = torch.rand(batch_size, channels, height, width, device=device, dtype=torch.float64)

        def grad_describe(patches):
            mkd_grads = MKDGradients()
            mkd_grads.to(device)
            return mkd_grads(patches)

        self.gradcheck(grad_describe, (patches,), nondet_tol=1e-4)


class TestVonMisesKernel(BaseTester):
    @pytest.mark.parametrize("ps", [5, 13, 25])
    def test_shape(self, ps, device):
        inp = torch.ones(1, 1, ps, ps).to(device)
        vm = VonMisesKernel(patch_size=ps, coeffs=[0.38214156, 0.48090413]).to(device)
        out = vm(inp)
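        # n coefficients produce 2 * n - 1 output channels (here n=2 -> 3).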
        assert out.shape == (1, 3, ps, ps)

    @pytest.mark.parametrize("bs", [1, 5, 13])
    def test_batch_shape(self, bs, device):
        inp = torch.ones(bs, 1, 15, 15).to(device)
        vm = VonMisesKernel(patch_size=15, coeffs=[0.38214156, 0.48090413]).to(device)
        out = vm(inp)
        assert out.shape == (bs, 3, 15, 15)

    @pytest.mark.parametrize("coeffs", COEFFS.values())
    def test_coeffs(self, coeffs, device):
        inp = torch.ones(1, 1, 15, 15).to(device)
        vm = VonMisesKernel(patch_size=15, coeffs=coeffs).to(device)
        out = vm(inp)
        assert out.shape == (1, 2 * len(coeffs) - 1, 15, 15)

    def test_print(self, device):
        vm = VonMisesKernel(patch_size=32, coeffs=[0.38214156, 0.48090413]).to(device)
        repr(vm)

    def test_toy(self, device):
        patch = torch.ones(1, 1, 6, 6).float().to(device)
        patch[0, 0, :, 3:] = 0
        vm = VonMisesKernel(patch_size=6, coeffs=[0.38214156, 0.48090413]).to(device)
        out = vm(patch)
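        # The first output channel is the constant term: sqrt(coeffs[0]) ~= 0.6182.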
        expected = torch.ones_like(out[0, 0, :, :]).to(device)
        self.assert_close(out[0, 0, :, :], expected * 0.6182, atol=1e-3, rtol=1e-3)

        expected = torch.tensor([0.3747, 0.3747, 0.3747, 0.6935, 0.6935, 0.6935], device=device)
        expected = expected.unsqueeze(0).repeat(6, 1)
        self.assert_close(out[0, 1, :, :], expected, atol=1e-3, rtol=1e-3)

        expected = torch.tensor([0.5835, 0.5835, 0.5835, 0.0000, 0.0000, 0.0000], device=device)
        expected = expected.unsqueeze(0).repeat(6, 1)
        self.assert_close(out[0, 2, :, :], expected, atol=1e-3, rtol=1e-3)

    def test_gradcheck(self, device):
        batch_size, channels, ps = 1, 1, 13
        patches = torch.rand(batch_size, channels, ps, ps, device=device, dtype=torch.float64)

        def vm_describe(patches, ps=13):
            vmkernel = VonMisesKernel(patch_size=ps, coeffs=[0.38214156, 0.48090413]).double()
            vmkernel.to(device)
            return vmkernel(patches.double())

        self.gradcheck(vm_describe, (patches, ps), nondet_tol=1e-4)

    @pytest.mark.jit()
    def test_jit(self, device, dtype):
        B, C, H, W = 2, 1, 13, 13
        patches = torch.rand(B, C, H, W, device=device, dtype=dtype)
        model = VonMisesKernel(patch_size=13, coeffs=[0.38214156, 0.48090413]).to(patches.device, patches.dtype).eval()
        model_jit = torch.jit.script(
            VonMisesKernel(patch_size=13, coeffs=[0.38214156, 0.48090413]).to(patches.device, patches.dtype).eval()
        )
        self.assert_close(model(patches), model_jit(patches))


class TestEmbedGradients(BaseTester):
    @pytest.mark.parametrize("ps,relative", [(5, True), (13, True), (25, True), (5, False), (13, False), (25, False)])
    def test_shape(self, ps, relative, device):
        inp = torch.ones(1, 2, ps, ps).to(device)
        emb_grads = EmbedGradients(patch_size=ps, relative=relative).to(device)
        out = emb_grads(inp)
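        # 7 channels: consistent with a 4-coefficient Von Mises embedding (2 * n - 1).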
        assert out.shape == (1, 7, ps, ps)

    @pytest.mark.parametrize("bs", [1, 5, 13])
    def test_batch_shape(self, bs, device):
        inp = torch.ones(bs, 2, 15, 15).to(device)
        emb_grads = EmbedGradients(patch_size=15, relative=True).to(device)
        out = emb_grads(inp)
        assert out.shape == (bs, 7, 15, 15)

    def test_print(self, device):
        emb_grads = EmbedGradients(patch_size=15, relative=True).to(device)
        repr(emb_grads)

    def test_toy(self, device):
        grads = torch.ones(1, 2, 6, 6).float().to(device)
        grads[0, 0, :, 3:] = 0
        emb_grads = EmbedGradients(patch_size=6, relative=True).to(device)
        out = emb_grads(grads)
        expected = torch.ones_like(out[0, 0, :, :3]).to(device)
        self.assert_close(out[0, 0, :, :3], expected * 0.3787, atol=1e-3, rtol=1e-3)
        self.assert_close(out[0, 0, :, 3:], expected * 0, atol=1e-3, rtol=1e-3)

    # TODO: review this test implementation; it was previously marked xfail with
    # "RuntimeError: Jacobian mismatch for output 0 with respect to input 0".
    def test_gradcheck(self, device):
        batch_size, channels, ps = 1, 2, 13
        patches = torch.rand(batch_size, channels, ps, ps, device=device, dtype=torch.float64)

        def emb_grads_describe(patches, ps=13):
            emb_grads = EmbedGradients(patch_size=ps, relative=True).double()
            emb_grads.to(device)
            return emb_grads(patches.double())

        self.gradcheck(emb_grads_describe, (patches, ps), nondet_tol=1e-4)

    @pytest.mark.jit()
    def test_jit(self, device, dtype):
        B, C, H, W = 2, 2, 13, 13
        patches = torch.rand(B, C, H, W, device=device, dtype=dtype)
        model = EmbedGradients(patch_size=W, relative=True).to(patches.device, patches.dtype).eval()
        model_jit = torch.jit.script(
            EmbedGradients(patch_size=W, relative=True).to(patches.device, patches.dtype).eval()
        )
        self.assert_close(model(patches), model_jit(patches))


@pytest.mark.parametrize("kernel_type,d,ps", [("cart", 9, 9), ("polar", 25, 9), ("cart", 9, 16), ("polar", 25, 16)])
def test_spatial_kernel_embedding(kernel_type, d, ps):
    grids = get_grid_dict(ps)
    spatial_kernel = spatial_kernel_embedding(kernel_type, grids)
    assert spatial_kernel.shape == (d, ps, ps)


class TestExplicitSpacialEncoding(BaseTester):
    @pytest.mark.parametrize(
        "kernel_type,ps,in_dims", [("cart", 9, 3), ("polar", 9, 3), ("cart", 13, 7), ("polar", 13, 7)]
    )
    def test_shape(self, kernel_type, ps, in_dims, device):
        inp = torch.ones(1, in_dims, ps, ps).to(device)
        ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=ps, in_dims=in_dims).to(device)
        out = ese(inp)
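        # The output flattens to (spatial dims) * in_dims: 9 spatial dims for
        # "cart", 25 for "polar" (cf. test_spatial_kernel_embedding).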
        d_ = 9 if kernel_type == "cart" else 25
        assert out.shape == (1, d_ * in_dims)

    @pytest.mark.parametrize(
        "kernel_type,bs", [("cart", 1), ("cart", 5), ("cart", 13), ("polar", 1), ("polar", 5), ("polar", 13)]
    )
    def test_batch_shape(self, kernel_type, bs, device):
        inp = torch.ones(bs, 7, 15, 15).to(device)
        ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=15, in_dims=7).to(device)
        out = ese(inp)
        d_ = 9 if kernel_type == "cart" else 25
        assert out.shape == (bs, d_ * 7)

    @pytest.mark.parametrize("kernel_type", ["cart", "polar"])
    def test_print(self, kernel_type, device):
        ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=15, in_dims=7).to(device)
        repr(ese)

    def test_toy(self, device):
        inp = torch.ones(1, 2, 6, 6).to(device).float()
        inp[0, 0, :, :] = 0
        cart_ese = ExplicitSpacialEncoding(kernel_type="cart", fmap_size=6, in_dims=2).to(device)
        out = cart_ese(inp)
        out_part = out[:, :9]
        expected = torch.zeros_like(out_part).to(device)
        self.assert_close(out_part, expected, atol=1e-3, rtol=1e-3)

        polar_ese = ExplicitSpacialEncoding(kernel_type="polar", fmap_size=6, in_dims=2).to(device)
        out = polar_ese(inp)
        out_part = out[:, :25]
        expected = torch.zeros_like(out_part).to(device)
        self.assert_close(out_part, expected, atol=1e-3, rtol=1e-3)

    @pytest.mark.parametrize("kernel_type", ["cart", "polar"])
    def test_gradcheck(self, kernel_type, device):
        batch_size, channels, ps = 1, 2, 13
        patches = torch.rand(batch_size, channels, ps, ps, device=device, dtype=torch.float64)

        def explicit_spatial_describe(patches, ps=13):
            ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=ps, in_dims=2)
            ese.to(device)
            return ese(patches)

        self.gradcheck(explicit_spatial_describe, (patches, ps), nondet_tol=1e-4)

    @pytest.mark.jit()
    def test_jit(self, device, dtype):
        B, C, H, W = 2, 2, 13, 13
        patches = torch.rand(B, C, H, W, device=device, dtype=dtype)
        model = (
            ExplicitSpacialEncoding(kernel_type="cart", fmap_size=W, in_dims=2).to(patches.device, patches.dtype).eval()
        )
        model_jit = torch.jit.script(
            ExplicitSpacialEncoding(kernel_type="cart", fmap_size=W, in_dims=2).to(patches.device, patches.dtype).eval()
        )
        self.assert_close(model(patches), model_jit(patches))


class TestWhitening(BaseTester):
    @pytest.mark.parametrize(
        "kernel_type,xform,output_dims",
        [
            ("cart", None, 3),
            ("polar", None, 3),
            ("cart", "lw", 7),
            ("polar", "lw", 7),
            ("cart", "pca", 9),
            ("polar", "pca", 9),
        ],
    )
    def test_shape(self, kernel_type, xform, output_dims, device):
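        # Input dims match the unwhitened cart (63) / polar (175) descriptor sizes.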
        in_dims = 63 if kernel_type == "cart" else 175
        wh = Whitening(xform=xform, whitening_model=None, in_dims=in_dims, output_dims=output_dims).to(device)
        inp = torch.ones(1, in_dims).to(device)
        out = wh(inp)
        assert out.shape == (1, output_dims)

    @pytest.mark.parametrize("bs", [1, 3, 7])
    def test_batch_shape(self, bs, device):
        wh = Whitening(xform="lw", whitening_model=None, in_dims=175, output_dims=128).to(device)
        inp = torch.ones(bs, 175).to(device)
        out = wh(inp)
        assert out.shape == (bs, 128)

    def test_print(self, device):
        wh = Whitening(xform="lw", whitening_model=None, in_dims=175, output_dims=128).to(device)
        repr(wh)

    def test_toy(self, device):
        wh = Whitening(xform="lw", whitening_model=None, in_dims=175, output_dims=175).to(device)
        inp = torch.ones(1, 175).to(device).float()
        out = wh(inp)
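        # 0.0756 ~= 1 / sqrt(175): an all-ones input after L2 normalization.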
        expected = torch.ones_like(inp).to(device) * 0.0756
        self.assert_close(out, expected, atol=1e-3, rtol=1e-3)

    def test_gradcheck(self, device):
        batch_size, in_dims = 1, 175
        patches = torch.rand(batch_size, in_dims, device=device, dtype=torch.float64)

        def whitening_describe(patches, in_dims=175):
            wh = Whitening(xform="lw", whitening_model=None, in_dims=in_dims).double()
            wh.to(device)
            return wh(patches.double())

        self.gradcheck(whitening_describe, (patches, in_dims), nondet_tol=1e-4)

    @pytest.mark.jit()
    def test_jit(self, device, dtype):
        batch_size, in_dims = 1, 175
        patches = torch.rand(batch_size, in_dims).to(device)
        model = Whitening(xform="lw", whitening_model=None, in_dims=in_dims).to(patches.device, patches.dtype).eval()
        model_jit = torch.jit.script(
            Whitening(xform="lw", whitening_model=None, in_dims=in_dims).to(patches.device, patches.dtype).eval()
        )
        self.assert_close(model(patches), model_jit(patches))


class TestMKDDescriptor(BaseTester):
    dims = {"cart": 63, "polar": 175, "concat": 238}

    @pytest.mark.parametrize(
        "ps,kernel_type", [(9, "concat"), (9, "cart"), (9, "polar"), (32, "concat"), (32, "cart"), (32, "polar")]
    )
    def test_shape(self, ps, kernel_type, device):
        mkd = MKDDescriptor(patch_size=ps, kernel_type=kernel_type, whitening=None).to(device)
        inp = torch.ones(1, 1, ps, ps).to(device)
        out = mkd(inp)
        assert out.shape == (1, self.dims[kernel_type])

    @pytest.mark.parametrize(
        "ps,kernel_type,whitening",
        [
            (9, "concat", "lw"),
            (9, "cart", "lw"),
            (9, "polar", "lw"),
            (9, "concat", "pcawt"),
            (9, "cart", "pcawt"),
            (9, "polar", "pcawt"),
        ],
    )
    def test_whitened_shape(self, ps, kernel_type, whitening, device):
        mkd = MKDDescriptor(patch_size=ps, kernel_type=kernel_type, whitening=whitening).to(device)
        inp = torch.ones(1, 1, ps, ps).to(device)
        out = mkd(inp)
        output_dims = min(self.dims[kernel_type], 128)
        assert out.shape == (1, output_dims)

    @pytest.mark.parametrize("bs", [1, 3, 7])
    def test_batch_shape(self, bs, device):
        mkd = MKDDescriptor(patch_size=19, kernel_type="concat", whitening=None).to(device)
        inp = torch.ones(bs, 1, 19, 19).to(device)
        out = mkd(inp)
        assert out.shape == (bs, 238)

    def test_print(self, device):
        mkd = MKDDescriptor(patch_size=32, whitening="lw", training_set="liberty", output_dims=128).to(device)
        repr(mkd)

    def test_toy(self, device):
        inp = torch.ones(1, 1, 6, 6).to(device).float()
        inp[0, 0, :, :] = 0
        mkd = MKDDescriptor(patch_size=6, kernel_type="concat", whitening=None).to(device)
        out = mkd(inp)
        out_part = out[0, -28:]
        expected = torch.zeros_like(out_part).to(device)
        self.assert_close(out_part, expected, atol=1e-3, rtol=1e-3)

    @pytest.mark.parametrize("whitening", [None, "lw", "pca"])
    def test_gradcheck(self, whitening, device):
        batch_size, channels, ps = 1, 1, 19
        patches = torch.rand(batch_size, channels, ps, ps, device=device, dtype=torch.float64)

        def mkd_describe(patches, patch_size=19):
            mkd = MKDDescriptor(patch_size=patch_size, kernel_type="concat", whitening=whitening).double()
            mkd.to(device)
            return mkd(patches.double())

        self.gradcheck(mkd_describe, (patches, ps), nondet_tol=1e-4)

    @pytest.mark.skip("neither dict, nor nn.ModuleDict works")
    @pytest.mark.jit()
    def test_jit(self, device, dtype):
        batch_size, channels, ps = 1, 1, 19
        patches = torch.rand(batch_size, channels, ps, ps).to(device)
        kt = "concat"
        wt = "lw"
        model = MKDDescriptor(patch_size=ps, kernel_type=kt, whitening=wt).to(patches.device, patches.dtype).eval()
        model_jit = torch.jit.script(
            MKDDescriptor(patch_size=ps, kernel_type=kt, whitening=wt).to(patches.device, patches.dtype).eval()
        )
        self.assert_close(model(patches), model_jit(patches))


class TestSimpleKD(BaseTester):
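    # Unwhitened dims per kernel type; SimpleKD whitens down to at most 128.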
    dims = {"cart": 63, "polar": 175}

    @pytest.mark.parametrize("ps,kernel_type", [(9, "cart"), (9, "polar"), (32, "cart"), (32, "polar")])
    def test_shape(self, ps, kernel_type, device):
        skd = SimpleKD(patch_size=ps, kernel_type=kernel_type).to(device)
        inp = torch.ones(1, 1, ps, ps).to(device)
        out = skd(inp)
        assert out.shape == (1, min(128, self.dims[kernel_type]))

    @pytest.mark.parametrize("bs", [1, 3, 7])
    def test_batch_shape(self, bs, device):
        skd = SimpleKD(patch_size=19, kernel_type="polar").to(device)
        inp = torch.ones(bs, 1, 19, 19).to(device)
        out = skd(inp)
        assert out.shape == (bs, 128)

    def test_print(self, device):
        skd = SimpleKD(patch_size=19, kernel_type="polar").to(device)
        repr(skd)

    def test_gradcheck(self, device):
        batch_size, channels, ps = 1, 1, 19
        patches = torch.rand(batch_size, channels, ps, ps, device=device, dtype=torch.float64)

        def skd_describe(patches, patch_size=19):
            skd = SimpleKD(patch_size=patch_size, kernel_type="polar", whitening="lw").double()
            skd.to(device)
            return skd(patches.double())

        self.gradcheck(skd_describe, (patches, ps), nondet_tol=1e-4)
